{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z8fQ_7Vn58mG"
      },
      "source": [
        "## Example Checkpointed Model Use\n",
        "\n",
        "Licensed under the Apache License, Version 2.0."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-cHvGcYI6DHP"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oVbWQLtu6EIU"
      },
      "outputs": [],
      "source": [
        "ckpt_path_template = 'gs://gresearch/reliable-deep-learning/checkpoints/baselines/diabetic_retinopathy_detection/deterministic/keras_model_step_63_{}/'\n",
        "ckpt_paths = [ckpt_path_template.format(i) for i in range(10)]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tEweaWoQECvL"
      },
      "outputs": [],
      "source": [
        "models = [tf.keras.models.load_model(path) for path in ckpt_paths]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cFeC3CpQ6HGN"
      },
      "outputs": [],
      "source": [
        "images = tf.random.normal((5, 512, 512, 3))  # Replace with your batch of images.\n",
        "logits_list = [model(images) for model in models]\n",
        "logits_list = tf.stack(logits_list, axis=0)\n",
        "probs_list = tf.nn.sigmoid(logits_list)\n",
        "probs = tf.reduce_mean(probs_list, axis=0)\n",
        "print(probs)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Example Checkpointed Model Use",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1ZtUX85rUJsV0Bnbxi2-Htn6BlEmwEwm5",
          "timestamp": 1624038085156
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
